{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 以经典的Iris数据集为例，取其中的部分数据来训练模型，然后测试模型能在多大程度上正确地判断（预测）余下的数据的标签。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 高斯贝叶斯算法，因为非常快速，而且不需要超参数，所以常常被首选用来做基础分类分析，当后续要对分析作出改进时，再尝试更换更好的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入Iris数据集。Scikit-Learn自带了这个经典的数据集\n",
    "from sklearn import datasets\n",
    "iris = datasets.load_iris()\n",
    "iris.data.shape, iris.target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 把数据集切分为训练集和测试集，可以人手切分，也可以使用Scikit-Learn提供的方便方法\n",
    "from sklearn.model_selection import train_test_split\n",
    "Xtrain, Xtest, ytrain, ytest = train_test_split(iris.data, iris.target,\n",
    "                                                test_size=0.20, random_state=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入高斯朴素贝叶斯模型的类对象\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "\n",
    "# 按照Scikit-Learn的常规步骤来做预测分析\n",
    "model = GaussianNB()\n",
    "model.fit(Xtrain, ytrain)\n",
    "ypredicted = model.predict(Xtest)\n",
    "\n",
    "# 打出预测好的结果\n",
    "ypredicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 接下来可以使用accuracy_score来查看预测结果与实际值之间的吻合程度\n",
    "from sklearn.metrics import accuracy_score\n",
    "accuracy_score(ytest, ypredicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试数据和预测结果的对比\n",
    "print(ytest)\n",
    "print(ypredicted)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
